import logging
import numpy as np
import tensorflow as tf
import time
import datetime
import uuid


class Model:
    """Frame classifier for assembly-station states, backed by a TF1 CNN checkpoint.

    ``run`` pulls frames from ``bind_function``, classifies each one, and about
    once per second reduces the collected predictions to their most frequent
    label. ``call_back_function`` is notified whenever that smoothed label
    changes. The default ``bind_function`` / ``call_back_function`` look like
    placeholders intended to be overridden.
    """

    def __init__(self, model_path):
        """Store the checkpoint path and the class-index -> state-name table.

        Args:
            model_path: checkpoint path passed to ``tf.train.Saver.restore``.
        """
        self.model_path = model_path
        # Class index -> workstation state name (runtime strings; English glosses in comments).
        self.label_name_dict = {
            0: "空台面",  # empty workbench
            1: "工件传入",  # workpiece coming in
            2: "拆卸顶盖",  # remove top cover
            3: "待装配",  # waiting for assembly
            4: "检测左侧两螺母孔",  # inspect the two left nut holes
            5: "安装左侧一号螺丝",  # insert left screw no. 1
            6: "安装左侧二号螺丝",  # insert left screw no. 2
            7: "拧紧左侧一号螺丝",  # tighten left screw no. 1
            8: "拧紧左侧二号螺丝",  # tighten left screw no. 2
            9: "标记左侧完工螺丝",  # mark finished left screws
            10: "装配右下侧工件",  # assemble lower-right part
            11: "装配横杠",  # assemble crossbar
            12: "拧横杠螺丝",  # tighten crossbar screws
            13: "标记完工横杠",  # mark finished crossbar
            14: "完工运走",  # finished, carried away
        }
        print('TensowFlow Version:', tf.__version__)

    def create_net(self):
        """Build the inference graph in the current default graph.

        Two conv/max-pool stages, a 400-unit ReLU dense layer with dropout, and
        a linear output layer with one unit per entry in ``label_name_dict``.

        Returns:
            ``(predicted_labels, dropout_placeholder, datas_placeholder)``: the
            argmax class index per input row, the scalar dropout-rate
            placeholder, and the float32 input placeholder of shape
            ``[None, 36, 62, 3]``.
        """
        num_classes = len(self.label_name_dict)
        datas_placeholder = tf.placeholder(tf.float32, [None, 36, 62, 3])
        dropout_placeholder = tf.placeholder(tf.float32)
        # Keep layer types/order unchanged: auto-generated variable names
        # presumably have to match those stored in the checkpoint.
        conv0 = tf.layers.conv2d(datas_placeholder, 20, 5, activation=tf.nn.relu)
        pool0 = tf.layers.max_pooling2d(conv0, [2, 2], [2, 2])
        conv1 = tf.layers.conv2d(pool0, 40, 4, activation=tf.nn.relu)
        pool1 = tf.layers.max_pooling2d(conv1, [2, 2], [2, 2])
        flatten = tf.layers.flatten(pool1)
        fc = tf.layers.dense(flatten, 400, activation=tf.nn.relu)
        # NOTE(review): no `training=` argument is passed, and tf.layers.dropout
        # defaults to inference mode, so the rate placeholder has no effect here.
        dropout_fc = tf.layers.dropout(fc, dropout_placeholder)
        logits = tf.layers.dense(dropout_fc, num_classes)
        # tf.argmax is the non-deprecated spelling of tf.arg_max (same ArgMax op).
        predicted_labels = tf.argmax(logits, 1)
        return predicted_labels, dropout_placeholder, datas_placeholder

    @staticmethod
    def _majority_label(labels):
        """Return the most frequent entry of the non-empty list ``labels``.

        Ties are resolved by ``set`` iteration order, exactly as
        ``max(set(labels), key=labels.count)`` does.
        """
        return max(set(labels), key=labels.count)

    def run(self):
        """Restore the checkpoint and classify frames until the source is exhausted.

        Each batch from ``bind_function`` is classified and the prediction for
        its first row is collected. Once at least one whole second has passed
        since the last report, the collected predictions are reduced to their
        most frequent label. That label is printed together with its state
        name. If it differs from the last reported label,
        ``call_back_function(label, "<uuid4>.jpg")`` is called first. The loop
        ends with a warning when ``bind_function`` returns ``None``.
        """
        predicted_labels, dropout_placeholder, datas_placeholder = self.create_net()
        saver = tf.train.Saver()
        labels = []
        timer = time.time()
        label_state = -1  # sentinel: no label has been reported yet
        with tf.Session() as sess:
            saver.restore(sess, self.model_path)
            while True:
                datas = self.bind_function()
                if datas is None:
                    logging.warning("Recived Data Is Empty")
                    break
                test_feed_dict = {
                    datas_placeholder: datas,
                    dropout_placeholder: 0,  # inference: no dropout
                }
                predicted_label = sess.run(predicted_labels, feed_dict=test_feed_dict)
                # Only the first row of the batch contributes to the vote.
                labels.append(predicted_label[0])
                end_time = time.time()
                if int(end_time - timer) >= 1:
                    timer = end_time
                    label = self._majority_label(labels)
                    # Name the *voted* label (not the last frame's) so the
                    # printed/recorded name always agrees with `label`.
                    _label = self.label_name_dict[label]
                    if label != label_state:
                        image_name = "%s.jpg" % (uuid.uuid4())
                        print('写入数据库：', label, "->", _label, '->', image_name)
                        self.call_back_function(label, image_name)
                        label_state = label
                    print(label, "->", _label)
                    labels.clear()

    def call_back_function(self, label, image_name):
        """Hook invoked by ``run`` when the smoothed label changes; default is a no-op.

        The log line printed just before the call ("写入数据库", "write to
        database") suggests an override is meant to persist the change.

        Args:
            label: the newly voted class index.
            image_name: a freshly generated ``"<uuid4>.jpg"`` file name.
        """
        return None

    def bind_function(self):
        """Return the next input batch for ``run``, or ``None`` to stop it.

        The default always returns one all-zero frame of shape
        ``(1, 36, 62, 3)``, so ``run`` never stops on its own. Presumably a
        subclass overrides this with a real frame source.
        """
        return np.zeros([1, 36, 62, 3])
